import os
import argparse
import numpy as np
from easydict import EasyDict
import wandb
import torch
import warnings

from utils.set_seed import set_seed
from utils.read_config_file import read_config_file
from utils.temperature_scaling import set_temperature_scaling
from utils.print_final_results import print_final_results
from data.get_dataloaders_cifar import get_dataloaders_cifar
from data.get_dataloaders_prostate_mri import get_dataloaders_prostate_mri
from data.get_dataloaders_tiny_imagenet import get_dataloaders_tiny_imagenet
from data.get_dataloaders_medmnist import get_dataloaders_medmnist
from src.trainer import Trainer
from src.eval import get_new_results


def _str2bool(value):
    """Parse a command-line flag value into a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"0"`` and ``"False"``) as ``True``, so a dedicated converter is needed.

    Args:
        value: Raw string from the command line (or an already-parsed bool).

    Returns:
        ``True`` for 1/true/yes/y/on, ``False`` for 0/false/no/n/off
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "Expected a boolean value (1/0, true/false), got {!r}.".format(value)
    )


def _build_parser():
    """Create the command-line parser for the train/test driver."""
    parser = argparse.ArgumentParser(description="Run train and/or test.")
    parser.add_argument(
        "--project_name",
        type=str,
        default="calibration-nn-classification",
        help="Project name, used in the checkpoint and plot directory paths.",
    )
    parser.add_argument(
        "--paths_config_file",
        type=str,
        required=True,
        help="Config file in ./configs that is read first; it is expected to "
        "provide base_config_path and loss_config_path.",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        help="Model name.",
    )
    parser.add_argument(
        "--train_mode",
        type=int,
        default=1,
        help="Whether run training.",
    )
    # NOTE(review): eval_mode is not read in main(); evaluation always runs.
    # Confirm whether another module consumes settings.eval_mode.
    parser.add_argument(
        "--eval_mode", type=int, default=1, help="Whether run evaluation."
    )
    parser.add_argument(
        "--base_config_file", type=str, help="Settings for dataset and model."
    )
    parser.add_argument(
        "--loss_config_file",
        type=str,
        help="Settings for the loss function to be used.",
    )
    parser.add_argument(
        "--resume_training",
        type=int,
        default=0,
        help="Whether resume a pre-trained model.",
    )
    parser.add_argument(
        "--use_pretrained_model",
        type=int,
        default=0,
        help="Whether start a whole new training from pretrained model.",
    )
    # NOTE(review): the *_runs entries are bookkeeping lists, not real CLI
    # inputs (type=list would split a command-line string into characters).
    parser.add_argument(
        "--best_val_auc_runs",
        type=list,
        default=[],
        help="List to track val AUC over runs.",
    )
    parser.add_argument(
        "--best_val_acc_runs",
        type=list,
        default=[],
        help="List to track val accuracy over runs.",
    )
    parser.add_argument(
        "--best_val_ece_runs",
        type=list,
        default=[],
        help="List to track val ece over runs.",
    )
    parser.add_argument(
        "--use_temperature_scaling",
        type=int,
        default=0,
        help="Whether to set a temperature for scaling.",
    )
    parser.add_argument(
        "--test_corruptions",
        type=int,
        default=0,
        help="Whether to test on a corrupted dataset.",
    )
    parser.add_argument(
        "--levels_corruption",
        type=int,
        default=5,
        help="Number of levels of corruption.",
    )
    parser.add_argument(
        "--corruption_type",
        default="gaussian_noise",
        type=str,
        help="Type of corruption.",
    )
    parser.add_argument(
        "--num_thresholds",
        type=int,
        help="How many thresholds to use to compute the ROC plots.",
    )
    parser.add_argument(
        "--gamma_FL",
        type=float,
        default=3.0,
        help="Parameter for focal loss.",
    )
    parser.add_argument(
        "--lamda",
        type=float,
        default=1.0,
        help="Weight for AUC loss.",
    )
    parser.add_argument(
        "--cudnn_benchmark",
        type=_str2bool,  # type=bool would turn "0" into True
        default=True,
        help="Set cudnn benchmark on (1) or off (0) (default is on).",
    )
    parser.add_argument(
        "--different_lr",
        type=int,
        default=0,
        help="Whether to use two different LR for primary and secondary loss function.",
    )
    parser.add_argument(
        "--use_scheduler",
        type=int,
        default=1,
        help="Whether to use a scheduler to train the network.",
    )
    parser.add_argument(
        "--use_scheduler_secondary",
        type=int,
        default=1,
        help="Whether to use a scheduler to train the network for the secondary loss in case of different LR.",
    )
    parser.add_argument(
        "--resize_medmnist",
        type=int,
        default=0,
        help="Whether to resize MedMNIST to 224x224.",
    )
    parser.add_argument(
        "--plot_together",
        type=int,
        default=0,
        help="Whether to plot together.",
    )
    parser.add_argument(
        "--ts_pre",
        type=float,
        default=0.0,
        help="Temperature value of pretrained model.",
    )
    return parser


def _run_dir(root, settings, loss_type, model_name):
    """Join ``root/project/dataset/net_type/batch_size/loss_type/model_name``."""
    return os.path.join(
        root,
        settings.project_name,
        settings.dataset,
        settings.net_type,
        str(settings.batch_size),
        loss_type,
        model_name,
    )


def _wandb_project_name(settings):
    """Return the wandb project name for the current TS / corruption setup.

    Only the exact flag combinations (1, 0), (0, 1) and (1, 1) for
    (use_temperature_scaling, test_corruptions) get a suffixed name; any
    other combination uses the batch-size based default.
    """
    dataset, net_type = settings.dataset, settings.net_type
    use_ts = settings.use_temperature_scaling
    corrupted = settings.test_corruptions
    if use_ts == 1 and corrupted == 0:
        return "calibration-{}-{}-TS".format(dataset, net_type)
    if use_ts == 0 and corrupted == 1:
        return "calibration-{}-{}-corrupted".format(dataset, net_type)
    if use_ts == 1 and corrupted == 1:
        return "calibration-{}-{}-corrupted-TS".format(dataset, net_type)
    return "calibration-{}-{}-{}".format(
        dataset, net_type, str(settings.batch_size)
    )


def _get_dataloaders(settings):
    """Return ``(train_loader, val_loader, test_loader)`` for settings.dataset.

    Raises:
        ValueError: If the dataset name matches none of the supported loaders.
    """
    # Order matters: substring checks ("cifar", "mnist") are tried as in the
    # original dispatch.
    if "cifar" in settings.dataset:
        return get_dataloaders_cifar(settings)
    if settings.dataset == "tiny-imagenet":
        return get_dataloaders_tiny_imagenet(settings)
    if settings.dataset == "prostate_mri":
        return get_dataloaders_prostate_mri(settings)
    if "mnist" in settings.dataset:
        return get_dataloaders_medmnist(settings)
    # Previously only warned, which left the loaders unbound and crashed later.
    raise ValueError("Dataset is not listed: {!r}.".format(settings.dataset))


def main():
    """Build settings, then train and test the model once per seed.

    Settings come from the command line and are passed through
    ``read_config_file`` for the paths, base and loss config files in turn.
    For every seed a wandb run is opened. The model is trained when
    ``train_mode == 1``. Then the best-ECE, best-accuracy and best-AUC
    checkpoints are tested. Finally, summary statistics over seeds are
    printed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Running on device: ", device)

    # Set parameters: CLI first, then the config files named by the CLI.
    settings = vars(_build_parser().parse_args())
    settings = read_config_file("configs", settings["paths_config_file"], settings)
    settings = read_config_file(
        settings["base_config_path"], settings["base_config_file"], settings
    )
    settings = read_config_file(
        settings["loss_config_path"], settings["loss_config_file"], settings
    )
    settings = EasyDict(settings)

    # Setup other parameters: device, directory for checkpoints and plots of this model
    settings.device = device
    settings.checkpoint_dir = _run_dir(
        settings.checkpoints_path, settings, settings.loss_type, settings.model_name
    )
    settings.plots_dir = _run_dir(
        "../plots", settings, settings.loss_type, settings.model_name
    )
    if settings.use_pretrained_model == 1:
        settings.checkpoint_pretrained_dir = _run_dir(
            settings.checkpoints_path,
            settings,
            settings.loss_type_pretrained,
            settings.model_pretrained_name,
        )
    os.makedirs(settings.plots_dir, exist_ok=True)
    print("Saving checkpoint at", settings.checkpoint_dir)

    # Run training and test once per seed (3 seeds), with fixed seeds for
    # reproducibility.
    seeds = np.arange(0, 3)
    os.environ["WANDB_SILENT"] = "true"
    os.environ["WANDB_START_METHOD"] = "thread"

    # For each model-selection criterion (which "best_*" checkpoint is
    # loaded), the metric lists handed to get_new_results across seeds.
    criteria = ("ece", "acc", "auc")
    test_results = {
        criterion: {"em_ece": [], "acc": [], "auc": []} for criterion in criteria
    }

    for seed in seeds:
        set_seed(seed)
        settings.seed = seed

        with wandb.init(
            project=_wandb_project_name(settings),
            config=settings,
            dir=settings.dir_wandb,
        ):
            train_loader, val_loader, test_loader = _get_dataloaders(settings)

            print(
                "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%  Run number {:2d}  %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%".format(
                    settings.seed
                )
            )
            wandb.run.name = settings.model_name + "/seed-{:2d}".format(settings.seed)

            # Train model
            if settings.train_mode == 1:
                Trainer(settings, train_loader, val_loader).train()

            # Test the best-ECE, best-accuracy and best-AUC checkpoints.
            for criterion in criteria:
                checkpoint_file = "{}/{}_{:02d}_best_{}.pth".format(
                    settings.checkpoint_dir,
                    settings.model_name,
                    settings.seed,
                    criterion,
                )
                if settings.use_temperature_scaling == 1:
                    set_temperature_scaling(val_loader, checkpoint_file, settings)
                # The best-ECE checkpoint is only tested on clean data.
                if criterion == "ece" and settings.test_corruptions != 0:
                    continue
                metrics = test_results[criterion]
                get_new_results(
                    settings,
                    checkpoint_file,
                    test_loader,
                    metrics["em_ece"],
                    metrics["acc"],
                    metrics["auc"],
                )

    # Calculate statistics over multiple runs
    if seeds.size > 1 and settings.test_corruptions == 0:
        # Argument order: (acc, em_ece, auc) for best-ECE, best-acc, best-AUC.
        per_criterion_metrics = [
            test_results[criterion][metric]
            for criterion in criteria
            for metric in ("acc", "em_ece", "auc")
        ]
        print_final_results(settings, *per_criterion_metrics)


if __name__ == "__main__":
    # Script entry point: run main() only when executed directly, not on import.
    main()
